import abc
from typing import Optional

from centralized_verification.shields.shield import ShieldResult
from centralized_verification.utils import TrainingProgress


class MultiAgent(abc.ABC):
    """Abstract interface for an object that controls a group of agents jointly.

    Concrete subclasses must implement every abstract method below before they
    can be instantiated.
    """

    @abc.abstractmethod
    def num_agents(self) -> int:
        """Return how many agents this controller is responsible for."""

    @abc.abstractmethod
    def get_joint_action(self, joint_observation, step_num: Optional[TrainingProgress]):
        """Choose an action for the agents given their joint observation.

        :param joint_observation: The combined observation for all agents
            (structure is defined by the concrete implementation).
        :param step_num: Will be None if in testing; otherwise the current
            training progress.
        :return: Presumably the joint action for all agents; the exact
            structure is left to subclasses.
        """

    @abc.abstractmethod
    def get_log_dict(self):
        """Return logging information for this controller.

        NOTE(review): the name suggests a dict of loggable values; the exact
        contents are defined by subclasses.
        """


class MultiAgentLearner(MultiAgent, abc.ABC):
    """A :class:`MultiAgent` that can also learn from observed transitions.

    In addition to acting, concrete learners must accept transitions and
    expose their internal state so that it can be saved and restored.
    """

    @abc.abstractmethod
    def observe_transition(self, joint_obs, shield_result: ShieldResult, joint_next_obs, joint_rew, done, step_num,
                           training_progress):
        """Receive a single joint transition to learn from.

        :param joint_obs: Joint observation before the step.
        :param shield_result: The ShieldResult for this step.
        :param joint_next_obs: Joint observation after the step.
        :param joint_rew: Joint reward for the step.
        :param done: Episode-termination indicator for the step.
        :param step_num: Step counter (type not specified here; confirm with callers).
        :param training_progress: Current training progress.
        """

    @abc.abstractmethod
    def state_dict(self):
        """Return this learner's internal state.

        NOTE(review): presumably mirrors the PyTorch ``state_dict`` convention,
        with the result being accepted by :meth:`load_state_dict`. Confirm.
        """

    @abc.abstractmethod
    def load_state_dict(self, state_dict):
        """Restore internal state from ``state_dict``.

        :param state_dict: State presumably produced by :meth:`state_dict`.
        """
